{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e0940c21-0473-456c-81f4-beb53f731dcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch.optim import Adam\n",
    "import torch.utils.data as data\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "import torchgan\n",
    "from torchgan.models import DCGANGenerator, DCGANDiscriminator\n",
    "from torchgan.losses import MinimaxGeneratorLoss, MinimaxDiscriminatorLoss\n",
    "from torchgan.trainer import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ae1addab-b4d0-445f-b07e-e308697ab8bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cifar10_dataloader(batch_size=128, root='./cifar10'):\n",
    "    \"\"\"Return a shuffled DataLoader over the CIFAR-10 training split.\n",
    "\n",
    "    Each image is converted to a tensor and then normalized per channel\n",
    "    with mean 0.5 and std 0.5.\n",
    "\n",
    "    Args:\n",
    "        batch_size: Number of images per batch (default 128).\n",
    "        root: Directory the dataset is read from; `download=True` asks\n",
    "            torchvision to fetch the data there when needed.\n",
    "\n",
    "    Returns:\n",
    "        A `torch.utils.data.DataLoader` over the training set.\n",
    "    \"\"\"\n",
    "    preprocess = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n",
    "    ])\n",
    "    train_dataset = dsets.CIFAR10(root=root, train=True, transform=preprocess, download=True)\n",
    "    return data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "555ba11f-c6bc-473f-9a73-8cbabbb7f983",
   "metadata": {},
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "Torch not compiled with CUDA enabled",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgenerator\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: DCGANGenerator, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mout_channels\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m3\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstep_channels\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m16\u001b[39m}, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moptimizer\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: Adam, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.0002\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbetas\u001b[39m\u001b[38;5;124m\"\u001b[39m: (\u001b[38;5;241m0.5\u001b[39m, \u001b[38;5;241m0.999\u001b[39m)}}},\n\u001b[1;32m      2\u001b[0m                    \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdiscriminator\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: DCGANDiscriminator, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124min_channels\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m3\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstep_channels\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m16\u001b[39m}, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moptimizer\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: Adam, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.0002\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbetas\u001b[39m\u001b[38;5;124m\"\u001b[39m: (\u001b[38;5;241m0.5\u001b[39m, \u001b[38;5;241m0.999\u001b[39m)}}}},\n\u001b[1;32m      3\u001b[0m                   [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],\n\u001b[1;32m      4\u001b[0m                   sample_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m, epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m20\u001b[39m)\n\u001b[1;32m      6\u001b[0m trainer(cifar10_dataloader())\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torchgan/trainer/trainer.py:104\u001b[0m, in \u001b[0;36mTrainer.__init__\u001b[0;34m(self, models, losses_list, metrics_list, device, ncritic, epochs, sample_size, checkpoints, retain_checkpoints, recon, log_dir, test_noise, nrow, **kwargs)\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_names\u001b[38;5;241m.\u001b[39mappend(key)\n\u001b[1;32m    102\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m model:\n\u001b[1;32m    103\u001b[0m     \u001b[38;5;28msetattr\u001b[39m(\n\u001b[0;32m--> 104\u001b[0m         \u001b[38;5;28mself\u001b[39m, key, (model[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m](\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m]))\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m    105\u001b[0m     )\n\u001b[1;32m    106\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    107\u001b[0m     \u001b[38;5;28msetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, key, (model[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m]())\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice))\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1156\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(device, dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1157\u001b[0m                     non_blocking, memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format)\n\u001b[1;32m   1158\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(device, dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, non_blocking)\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_apply(convert)\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torch/nn/modules/module.py:810\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m    808\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m    809\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 810\u001b[0m         module\u001b[38;5;241m.\u001b[39m_apply(fn)\n\u001b[1;32m    812\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m    813\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m    814\u001b[0m         \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m    815\u001b[0m         \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    820\u001b[0m         \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m    821\u001b[0m         \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torch/nn/modules/module.py:810\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m    808\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m    809\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 810\u001b[0m         module\u001b[38;5;241m.\u001b[39m_apply(fn)\n\u001b[1;32m    812\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m    813\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m    814\u001b[0m         \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m    815\u001b[0m         \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    820\u001b[0m         \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m    821\u001b[0m         \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torch/nn/modules/module.py:810\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m    808\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m    809\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 810\u001b[0m         module\u001b[38;5;241m.\u001b[39m_apply(fn)\n\u001b[1;32m    812\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m    813\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m    814\u001b[0m         \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m    815\u001b[0m         \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    820\u001b[0m         \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m    821\u001b[0m         \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torch/nn/modules/module.py:833\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m    829\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m    830\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m    831\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m    832\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 833\u001b[0m     param_applied \u001b[38;5;241m=\u001b[39m fn(param)\n\u001b[1;32m    834\u001b[0m should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m    835\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_use_set_data:\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torch/nn/modules/module.py:1158\u001b[0m, in \u001b[0;36mModule.to.<locals>.convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m   1155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m   1156\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(device, dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1157\u001b[0m                 non_blocking, memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format)\n\u001b[0;32m-> 1158\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(device, dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, non_blocking)\n",
      "File \u001b[0;32m~/miniconda3/envs/DL/lib/python3.11/site-packages/torch/cuda/__init__.py:289\u001b[0m, in \u001b[0;36m_lazy_init\u001b[0;34m()\u001b[0m\n\u001b[1;32m    284\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    285\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot re-initialize CUDA in forked subprocess. To use CUDA with \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    286\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmultiprocessing, you must use the \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mspawn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m start method\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    287\u001b[0m     )\n\u001b[1;32m    288\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(torch\u001b[38;5;241m.\u001b[39m_C, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_cuda_getDeviceCount\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 289\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTorch not compiled with CUDA enabled\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    290\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _cudart \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    291\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\n\u001b[1;32m    292\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlibcudart functions unavailable. It looks like you have a broken build?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    293\u001b[0m     )\n",
      "\u001b[0;31mAssertionError\u001b[0m: Torch not compiled with CUDA enabled"
     ]
    }
   ],
   "source": [
    "# Pick the device explicitly: without a `device` argument the Trainer moves the\n",
    "# models to CUDA, which fails on CPU-only PyTorch builds with\n",
    "# \"Torch not compiled with CUDA enabled\".\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyperparameters shared by the generator and the discriminator.\n",
    "LEARNING_RATE = 0.0002\n",
    "ADAM_BETAS = (0.5, 0.999)\n",
    "STEP_CHANNELS = 16\n",
    "IMAGE_CHANNELS = 3  # RGB images\n",
    "\n",
    "gan_models = {\n",
    "    \"generator\": {\n",
    "        \"name\": DCGANGenerator,\n",
    "        \"args\": {\"out_channels\": IMAGE_CHANNELS, \"step_channels\": STEP_CHANNELS},\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": LEARNING_RATE, \"betas\": ADAM_BETAS}},\n",
    "    },\n",
    "    \"discriminator\": {\n",
    "        \"name\": DCGANDiscriminator,\n",
    "        \"args\": {\"in_channels\": IMAGE_CHANNELS, \"step_channels\": STEP_CHANNELS},\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": LEARNING_RATE, \"betas\": ADAM_BETAS}},\n",
    "    },\n",
    "}\n",
    "\n",
    "trainer = Trainer(gan_models,\n",
    "                  [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],\n",
    "                  device=device, sample_size=64, epochs=20)\n",
    "\n",
    "trainer(cifar10_dataloader())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
